#!/bin/bash
# Environment setup for multi-NPU (Ascend) inference.

# Serialize CUDA/NPU stream connections; required by Megatron-style launchers.
# (The original exported this twice — deduplicated.)
export CUDA_DEVICE_MAX_CONNECTIONS=1
# Restrict this job to the first four NPUs on the node.
export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3
# NOTE(review): /usr/local/lib appeared twice in the original value — deduplicated.
export LD_LIBRARY_PATH=/usr/lib64:$LD_LIBRARY_PATH:/usr/local/lib:/root/miniconda3/lib
# Allow up to 20 minutes for HCCL collective-communication connection setup.
export HCCL_CONNECT_TIMEOUT=1200
export COMBINED_ENABLE=1

# Load the CANN toolkit and ATB acceleration library environments.
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

# ---------------- Model & tokenizer configuration ----------------
MODEL_TAG=MindSpeed_Infinity_Instruct_7M_3M_5e-7
# Adjust these two paths to where the HF model/tokenizer lives on this machine.
TOKENIZER_PATH=./model_from_hf/Llama-2-7b-hf
TOKENIZER_MODEL=${TOKENIZER_PATH}/tokenizer.model
# Alternative: point at a specific iteration checkpoint instead of ./ckpt:
# CHECKPOINT=./outputs/Llama-2-7b/tuned_with_MindSpeed_Infinity_Instruct_7M_3M_5e-7/1x1/iter_${CKPT_STEP}/
CHECKPOINT=./outputs/Llama-2-7b/tuned_with_${MODEL_TAG}/ckpt

# ---------------- Distributed topology ----------------
# Change these for a multi-node run.
MASTER_ADDR=localhost
MASTER_PORT=6023
NNODES=1
NODE_RANK=0
NPUS_PER_NODE=4
WORLD_SIZE=$(( NPUS_PER_NODE * NNODES ))
TP=1    # tensor model parallel size
PP=4    # pipeline model parallel size

# Launcher arguments as an array so each flag remains a separate word
# (avoids relying on unquoted word-splitting of a flat string).
DISTRIBUTED_ARGS=(
    --nproc_per_node "$NPUS_PER_NODE"
    --nnodes "$NNODES"
    --node_rank "$NODE_RANK"
    --master_addr "$MASTER_ADDR"
    --master_port "$MASTER_PORT"
)

# tee(1) does not create parent directories; without this the whole log is lost.
mkdir -p logs

# Make the pipeline's exit status reflect the inference process, not tee.
set -o pipefail

# NOTE(review): torch.distributed.launch is deprecated in recent PyTorch in
# favor of torchrun; kept as-is because inference.py may depend on the
# --local_rank argument that the legacy launcher passes — confirm before migrating.
python -m torch.distributed.launch "${DISTRIBUTED_ARGS[@]}" inference.py \
       --tensor-model-parallel-size ${TP}  \
       --pipeline-model-parallel-size ${PP}  \
       --use-mcore-models \
       --task chat \
       --prompt-type llama2 \
       --use-kv-cache \
       --use-flash-attn \
       --top-p 0.6 \
       --temperature 0.8 \
       --use-fused-swiglu \
       --use-fused-rmsnorm \
       --use-fused-rotary-pos-emb \
       --num-layers 32 \
       --hidden-size 4096  \
       --ffn-hidden-size 11008 \
       --position-embedding-type rope \
       --seq-length 8192 \
       --max-new-tokens 512 \
       --micro-batch-size 1 \
       --global-batch-size 4 \
       --num-attention-heads 32  \
       --max-position-embeddings 8192 \
       --swiglu \
       --load "${CHECKPOINT}"  \
       --tokenizer-type PretrainedFromHF  \
       --tokenizer-name-or-path "${TOKENIZER_PATH}" \
       --tokenizer-model "${TOKENIZER_MODEL}"  \
       --tokenizer-not-use-fast \
       --fp16 \
       --normalization RMSNorm \
       --untie-embeddings-and-output-weights \
       --disable-bias-linear \
       --attention-softmax-in-fp32 \
       --no-load-optim \
       --no-load-rng \
       --no-masked-softmax-fusion \
       --no-gradient-accumulation-fusion \
       --exit-on-missing-checkpoint \
       --make-vocab-size-divisible-by 1 \
       | tee "logs/chat_${MODEL_TAG}.log"


